# -*- coding: utf-8 -*-
# %%
import numpy as np
import scipy.ndimage as ndimage
import scipy.io as sio
import matplotlib.pyplot as plt
import os.path
import matplotlib.patches as patches
import time
import sys
import h5py


# %%
# NOTE: disabled helper, kept only as an inert string literal (it is never
# executed and nothing in this file calls atomcenter_pos). If re-enabled it
# would return 18 hard-coded 4-element arrays whose first two entries vary and
# whose last two are 0 -- presumably (x, y) pixel centres of the atom clouds
# plus unused slots; confirm before relying on it.
'''def atomcenter_pos():
    atom_center = [np.array([208, 276, 0, 0]), np.array([257, 265, 0, 0]), np.array([307, 254, 0, 0]),
                   np.array([318, 305, 0, 0]), np.array([329, 354, 0, 0]), np.array([280, 366, 0, 0]),
                   np.array([230, 377, 0, 0]), np.array([217, 327, 0, 0]), np.array([269, 316, 0, 0]),
                   np.array([279, 107, 0, 0]), np.array([329, 94, 0, 0]), np.array([381, 83, 0, 0]),
                   np.array([392, 136, 0, 0]), np.array([404, 186, 0, 0]), np.array([354, 196, 0, 0]),
                   np.array([302, 209, 0, 0]), np.array([290, 155, 0, 0]), np.array([342, 147, 0, 0])]
    return atom_center'''


# %%
def im_rot_copy(image, cut_width, rot_copy_switch, cut_switch, width_min):
    """Mask a square cut-out to a central disc, optionally filling corners by rotation.

    Python port of ``imRotCopy.m``.

    Parameters
    ----------
    image : ndarray
        Square array of shape ``(cut_width, cut_width)``; it is multiplied
        element-wise with masks of that shape.
    cut_width : int
        Side length of ``image``.
    rot_copy_switch : int
        When 1, the parts of the top-left and bottom-right quadrants that lie
        outside the disc are kept, and a copy of them rotated by 90 degrees
        (``scipy.ndimage.rotate``) is added on top.
    cut_switch : int
        When not 1, ``image`` itself is returned unchanged.
    width_min : int or float
        The disc radius is ``ceil(cut_width / 2) - width_min``.

    Returns
    -------
    ndarray
        The masked (and optionally corner-filled) image, or ``image`` itself
        when ``cut_switch != 1``.
    """
    if cut_switch != 1:
        return image

    cut_x, cut_y = np.meshgrid(np.arange(cut_width), np.arange(cut_width))
    center_x = np.ceil(cut_width / 2)
    center_y = np.ceil(cut_width / 2)
    circle_radial = center_x - width_min
    cut_range = np.sqrt((cut_x - center_x) ** 2 + (cut_y - center_y) ** 2)

    # 1 strictly inside the disc of radius circle_radial, 0 elsewhere.
    cut_map_0 = (cut_range < circle_radial).astype(float)
    cut_map = cut_map_0 * image

    if rot_copy_switch == 1:
        outside = cut_range > circle_radial
        # Element-wise `&`/`|` are required: Python's `and` on arrays raises
        # "truth value of an array is ambiguous".
        corner_mask = outside & (
            ((cut_x <= center_x) & (cut_y <= center_y))
            | ((cut_x >= center_x) & (cut_y >= center_y))
        )
        corners = corner_mask * image
        cut_map = cut_map + corners + ndimage.rotate(corners, 90)
    return cut_map


# %%
def find_atom_contour(data, threshold, figure_index):
    """Zero every element of ``data`` that is not strictly above ``threshold``.

    Parameters
    ----------
    data : ndarray
        Image to mask.
    threshold : float
        Elements ``> threshold`` are kept; the rest become 0. The result is
        computed as ``data * mask``, so NaN/inf inputs follow NumPy's
        multiplication rules.
    figure_index : int
        Unused; kept for call compatibility.

    Returns
    -------
    ndarray
        Masked copy of ``data`` with the same dtype.
    """
    keep = threshold < data
    return data * keep


# %%
def cal_atomsize(od_image, threshold, min_size=16):
    """Estimate the size of an atom cloud in pixel-index units.

    Parameters
    ----------
    od_image : ndarray
        2-D image of one cloud.
    threshold : sequence of two numbers
        ``threshold[0]`` is the minimum total signal (``sum(od_image)``) for a
        cloud to count as present. ``threshold[1]`` is the per-pixel level
        whose exceeding pixels define the cloud's extent.
    min_size : int or float, optional
        Value returned when no pixel exceeds ``threshold[1]``, and the lower
        bound on any computed size (default 16, the previously hard-coded
        value).

    Returns
    -------
    number
        * 0 if ``sum(od_image) < threshold[0]``;
        * ``min_size`` if no pixel is ``> threshold[1]``;
        * otherwise the mean of the row extent and the column extent
          (max index - min index) of the pixels above ``threshold[1]``,
          clamped below at ``min_size``.
    """
    if np.sum(od_image) < threshold[0]:
        return 0
    rows, cols = np.where(od_image > threshold[1])
    if rows.size == 0:
        return min_size
    mean_extent = ((rows.max() - rows.min()) + (cols.max() - cols.min())) / 2
    return max(mean_extent, min_size)


def is_bec(od_image, atom_size, threshold):
    """Label a cloud by its peak optical density per unit size.

    The ratio ``bect = max(od_image) / atom_size`` is compared with the two
    thresholds:

    * ``bect <= threshold[0]``                 -> 0
    * ``threshold[0] < bect <= threshold[1]``  -> 0.5
    * ``bect > threshold[1]``                  -> 1

    The label is intended to select different starting parameters for
    fitting.

    Raises
    ------
    ValueError
        If the ratio is NaN (e.g. the image contains NaN, or a zero peak is
        divided by ``atom_size == 0``), since no band can classify it.
    """
    od_max = np.max(od_image)
    bect = od_max / atom_size
    if np.isnan(bect):
        raise ValueError('is_bec: peak/size ratio is NaN and cannot be classified')
    if bect <= threshold[0]:
        return 0
    if bect <= threshold[1]:
        return 0.5
    return 1


# %%
def Atom_N1(Data_Ori, Center_Pos, Width, Background, Figure_Index, \
            Matrix_Shift, width_min, rot_copy_switch, \
            cut_switch, Mode, find_atom_contour_threshold, **kw):
    """Cut a rectangular window out of an image and sum it against a background.

    Parameters
    ----------
    Data_Ori : ndarray
        Full 2-D image, indexed as ``Data_Ori[row, col]``.
    Center_Pos : sequence
        ``(x, y)`` window centre; it is rounded up with ``np.ceil``. The window
        starts at ``ceil(center) - floor(width / 2) + 1`` along each axis.
        NOTE(review): that ``+ 1`` puts odd-width windows one pixel past the
        centre -- possibly a MATLAB 1-based-index leftover; confirm.
    Width : int or [width_x, width_y]
        Window size (square when a single number).
    Background : float
        Mode 0/100: subtracted from every pixel. Mode 4: pixels below it are
        set to 0.
    Figure_Index, Matrix_Shift, find_atom_contour_threshold
        Unused; kept for call compatibility.
    width_min, rot_copy_switch, cut_switch
        Forwarded to ``im_rot_copy`` when ``Circle_Cut=1`` is passed.
    Mode : {0, 4, 100}
        0 and 100: return ``(sum(cut - Background), cut - Background)``.
        4: return ``(sum(thresholded cut), thresholded cut)``.
    **kw
        ``Circle_Cut=1`` applies ``im_rot_copy`` to the cut.

    Raises
    ------
    ValueError
        For any other ``Mode``. Previously Mode 5 and unknown modes failed
        with UnboundLocalError.
    """
    if Mode not in (0, 4, 100):
        raise ValueError('Unsupported Mode %r; expected 0, 4 or 100' % (Mode,))

    Center_Pos = np.ceil(Center_Pos)
    if isinstance(Width, (list, tuple)):
        width_x, width_y = Width[0], Width[1]
    else:
        width_x = width_y = Width
    start_x = int(Center_Pos[0] - np.floor(width_x / 2)) + 1
    start_y = int(Center_Pos[1] - np.floor(width_y / 2)) + 1
    cols = np.arange(start_x, start_x + width_x).reshape((1, width_x))
    rows = np.arange(start_y, start_y + width_y).reshape((width_y, 1))
    # (width_y, 1) x (1, width_x) fancy indexing yields a (width_y, width_x) copy,
    # so the in-place edit in Mode 4 never touches Data_Ori.
    Data_Cut = Data_Ori[rows, cols]
    if kw.get('Circle_Cut') == 1:
        Data_Cut = im_rot_copy(Data_Cut, Width, rot_copy_switch, cut_switch, width_min)

    if Mode == 4:
        Data_Cut[Data_Cut < Background] = 0
        return np.sum(Data_Cut), Data_Cut

    # Modes 0 and 100 are identical: plain background subtraction.
    Diff_Data_Cut = Data_Cut - Background
    return np.sum(Diff_Data_Cut), Diff_Data_Cut


# %%
def _read_odimage_mat(path, use_h5):
    """Return the ``odimage`` variable stored in the MATLAB file ``path``."""
    if use_h5:
        # MATLAB v7.3 files are HDF5; h5py returns the array transposed
        # relative to MATLAB's (column-major) layout, hence ``.T``.
        with h5py.File(path, 'r') as f:
            return np.array(f['odimage']).T
    return sio.loadmat(path)['odimage']


def Save_And_Load_Data(Load_Name, Save_Name, File_Mode, **kw):
    '''
    Load an ``odimage`` array, caching it as ``Save_Name.npy`` and ``Save_Name.mat``.

    Save_Name is given without extension. If neither Save_Name.mat nor
    Save_Name.npy exists, the image is read from Load_Name and written to both
    cache files. Otherwise the existing cache is read (.mat preferred) and the
    missing cache file is (re)created.

    File_Mode (only used when no cache exists):
        0 -> raw picture: not implemented, raises NotImplementedError;
        1 -> MATLAB file (background removed with Matlab), variable 'odimage';
        2 -> NumPy .npz (background removed with python), key 'odimage'.
    Any other File_Mode raises ValueError.

    Keyword arguments:
        MatVersion='7.3' -> read MATLAB files with h5py instead of
        scipy.io.loadmat (which does not support v7.3 files).

    Returns the loaded array.
    '''
    mat_path = Save_Name + '.mat'
    npy_path = Save_Name + '.npy'
    use_h5 = kw.get('MatVersion') == '7.3'

    if not os.path.exists(mat_path) and not os.path.exists(npy_path):
        if File_Mode == 1:
            od_image = _read_odimage_mat(Load_Name, use_h5)
        elif File_Mode == 2:
            with np.load(Load_Name) as npz_file:
                od_image = npz_file['odimage']
        elif File_Mode == 0:
            raise NotImplementedError('File_Mode 0 (raw picture) is not supported')
        else:
            raise ValueError('Unknown File_Mode %r' % (File_Mode,))
        print('Calculating ' + Load_Name)
        np.save(npy_path, od_image)
        sio.savemat(mat_path, {'odimage': od_image})
    elif os.path.exists(mat_path):
        # Read the cached .mat itself (previously the v7.3 path re-read Load_Name).
        od_image = _read_odimage_mat(mat_path, use_h5)
        print('Calculating ' + Save_Name)
        if not os.path.exists(npy_path):
            try:
                np.save(npy_path, od_image)
            except PermissionError:
                # Best effort: the .mat cache alone is enough to continue.
                pass
    else:
        # Only the .npy cache exists; it holds the bare array written by np.save.
        od_image = np.load(npy_path)
        print('Calculating ' + Save_Name)
        if not os.path.exists(mat_path):
            sio.savemat(mat_path, {'odimage': od_image})
    return od_image


# def Draw_Atoms(Data, Photo_Ind, x_range, y_range, SaveName):
#     fig = plt.figure(Photo_Ind)
#     # plt.imshow(Data_Ori, vmin=-0.1, vmax=0.5, cmap='jet')
#     plt.imshow(Data_Ori[170:350, 150:330], vmin=-0.1, vmax=0.5, cmap='jet')
#     fig.savefig(SaveName+'.png')
#     fig.savefig(SaveName+'.pdf')
#     plt.pause(0.5)
#     plt.close()
# %%
def Load_Pic_And_Cal(Photo_Ind, Load_Name, File_Mode, Save_Name, SaveFig_Address, Atom_Center, cut_width, Cloud_Number,
                     OriImage_Switch, **kw):
    """Load one image and return each cloud's share of the summed signal.

    Parameters
    ----------
    Photo_Ind : int
        Photo index, used in the plot label and the saved figure file name.
    Load_Name, File_Mode, Save_Name
        Forwarded to ``Save_And_Load_Data``.
    SaveFig_Address : str
        Prefix of the saved figure path: ``<prefix><Photo_Ind>Ori_Image.png``.
    Atom_Center : sequence
        ``(x, y)`` centre per cloud; the LAST entry is the background window.
    cut_width : sequence
        Window size per entry of ``Atom_Center`` (int for a square, or
        ``[width_x, width_y]``); the LAST entry is the background window.
    Cloud_Number : int
        Number of leading entries of ``Atom_Center`` to integrate.
    OriImage_Switch : int
        When 1, the (cropped) image is plotted with cloud outlines and saved.

    Keyword arguments
    -----------------
    Rot_Ang : float
        Rotate the loaded image by this angle first (shape preserved).
    Cut_Aera : ([x_start, x_end], [y_start, y_end])
        Crop (key spelled as-is) used for the plot and for ``Region_All``.
    Enclose : sequence of str
        Per-cloud outline: 'Rectangle', 'Circle', or '0' to stop drawing
        further outlines. 'Circle' also integrates with ``Circle_Cut=1``.
    Axis : str
        'off' hides the plot axes.

    Returns
    -------
    Atom_Ratio : list of float
        Each cloud's background-subtracted sum divided by the total.
    Atom_Regions : list of ndarray
        The per-cloud regions returned by ``Atom_N1``.
    Region_All : ndarray
        Copy of the (cropped) image.
    """
    Wait_Sec = 1
    Atom_Mode = 0
    find_atom_contour_threshold = 0.05
    Background_Center = Atom_Center[-1]
    Background_Wid = cut_width[-1]
    Matrix_Shift = np.array([0, 0])

    # Retry until the image can be read (e.g. the file does not exist yet).
    while True:
        try:
            Data_Ori = Save_And_Load_Data(Load_Name, Save_Name, File_Mode)
            break
        except NotImplementedError:
            # scipy.io.loadmat raises this for MATLAB v7.3 (HDF5) files.
            Data_Ori = Save_And_Load_Data(Load_Name, Save_Name, File_Mode, MatVersion='7.3')
            break
        except Exception:
            # Unlike the former bare `except:`, KeyboardInterrupt and
            # SystemExit now propagate instead of being retried forever.
            time.sleep(Wait_Sec)

    if 'Rot_Ang' in kw:
        Data_Ori = ndimage.rotate(Data_Ori, kw['Rot_Ang'], reshape=False)

    # x indexes columns (axis 1) and y indexes rows (axis 0), matching the
    # Data_Ori[y, x] slicing below; the old default swapped the two.
    n_rows, n_cols = np.shape(Data_Ori)[:2]
    x_range = [0, n_cols]
    y_range = [0, n_rows]
    if 'Cut_Aera' in kw:
        x_range = kw['Cut_Aera'][0]
        y_range = kw['Cut_Aera'][1]
    x_Start, x_End = x_range[0], x_range[1]
    y_Start, y_End = y_range[0], y_range[1]
    Region_All = Data_Ori[y_Start:y_End, x_Start:x_End].copy()

    if OriImage_Switch == 1:
        fig, ax = plt.subplots(1)
        plt.imshow(Region_All, vmin=-0.1, vmax=0.5, cmap='jet')
        theta = np.linspace(-np.pi, np.pi, 1000)
        for Atom_Ind in range(len(Atom_Center) - 1):
            shape = kw['Enclose'][Atom_Ind] if 'Enclose' in kw else None
            if shape == '0':
                break
            if shape not in ('Rectangle', 'Circle'):
                continue
            width = cut_width[Atom_Ind]
            if isinstance(width, list):
                width_x, width_y = width[0], width[1]
            else:
                width_x = width_y = width
            # Outline centre in the coordinates of the cropped image.
            center_x = Atom_Center[Atom_Ind][0] - x_Start
            center_y = Atom_Center[Atom_Ind][1] - y_Start
            if shape == 'Rectangle':
                rect = patches.Rectangle((center_x - width_x / 2, center_y - width_y / 2),
                                         width_x, width_y, linewidth=1, edgecolor='r',
                                         facecolor='none')
                ax.add_patch(rect)
            else:
                plt.plot(width_x / 2 * np.cos(theta) + center_x,
                         width_y / 2 * np.sin(theta) + center_y, 'k')
        plt.text(75, 45, 'photo%d' % Photo_Ind, color='w')
        if kw.get('Axis') == 'off':
            plt.axis('off')
        fig.savefig(SaveFig_Address + str(Photo_Ind) + r'Ori_Image' + '.png')
        plt.pause(0.5)

    # Mean value per pixel in the background window (Mode 100, Background=0).
    # Dividing by the region's size also works for [width_x, width_y] windows.
    Background_Sum, Background_Region = Atom_N1(Data_Ori, Background_Center, Background_Wid, 0, 50,
                                                Matrix_Shift, 0, 0, 0, 100, 0)
    Background = Background_Sum / Background_Region.size

    Atom_Nums = []
    Atom_Regions = []
    for Atom_Ind in range(Cloud_Number):
        extra_kw = {}
        if 'Enclose' in kw and kw['Enclose'][Atom_Ind] == 'Circle':
            extra_kw['Circle_Cut'] = 1
        Atom_Num, Atom_Region = Atom_N1(Data_Ori, Atom_Center[Atom_Ind], cut_width[Atom_Ind],
                                        Background, Atom_Ind, Matrix_Shift, 0, 0, 1, Atom_Mode,
                                        find_atom_contour_threshold, **extra_kw)
        # NOTE: this constant scale factor cancels out in Atom_Ratio below.
        Atom_Nums.append(Atom_Num * 112.2)
        Atom_Regions.append(Atom_Region)

    Atom_Nums = np.array(Atom_Nums)
    Atom_Ratio = (Atom_Nums / np.sum(Atom_Nums)).tolist()
    return Atom_Ratio, Atom_Regions, Region_All


def Polylog(z, s, epsilon, max_terms=10 ** 6):
    """Approximate the polylogarithm Li_s(z) = sum_{n>=1} z**n / n**s.

    Terms are added while their magnitude exceeds ``epsilon``. The first term
    with ``|term| <= epsilon`` ends the summation and is not added.

    Parameters
    ----------
    z : float or complex
        Argument; ``|z| <= 1`` is required.
    s : float
        Order of the polylogarithm.
    epsilon : float
        Stopping tolerance on the magnitude of a single term.
    max_terms : int, optional
        Safety limit on the number of terms examined.

    Raises
    ------
    ValueError
        If ``|z| > 1``, where the series diverges.
    RuntimeError
        If ``max_terms`` terms do not reach the tolerance (e.g. ``|z| == 1``
        with ``s <= 0``).
    """
    if abs(z) > 1:
        raise ValueError('Polylog series diverges for |z| > 1 (got z=%r)' % (z,))
    g = 0
    for n in range(1, max_terms + 1):
        term = z ** n / n ** s
        # abs() so that alternating series (negative z) are summed too.
        if abs(term) <= epsilon:
            return g
        g += term
    raise RuntimeError('Polylog did not converge within %d terms' % max_terms)
